{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Wideband Dataset\n",
    "This notebook showcases the Wideband dataset.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define Variables\n",
    "\n",
    "num_iq_samples_dataset = 4096 # 64^2\n",
    "fft_size = 64\n",
    "impairment_level = 0 # clean"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Wideband Metadata\n",
    "To create a NewWideband dataset, you must first define its parameters in WidebandMetadata. This can be done in code, either as a `WidebandMetadata` object or as a dictionary, or in a YAML file. All three options are shown below. See `wideband_example.yaml` for a sample YAML file.\n",
    "\n",
    "There are four required parameters:\n",
    "1. `num_iq_samples_dataset` -> number of IQ samples in each dataset sample.\n",
    "2. `fft_size` -> size of the FFT (number of bins) used in the spectrogram.\n",
    "3. `impairment_level` -> which environment impairment level to simulate.\n",
    "4. `num_signals_max` -> maximum number of signals per sample.\n",
    "\n",
    "Additionally, there are several optional parameters that can be overridden."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Option 1: Instantiate WidebandMetadata object\n",
    "from torchsig.datasets.dataset_metadata import WidebandMetadata\n",
    "\n",
    "wideband_metadata_1 = WidebandMetadata(\n",
    "    num_iq_samples_dataset = num_iq_samples_dataset, # 64^2\n",
    "    fft_size = fft_size,\n",
    "    impairment_level = impairment_level, # clean\n",
    "    num_signals_max = 3,\n",
    ")\n",
    "print(wideband_metadata_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Option 2: Instantiate as a dictionary object\n",
    "\n",
    "wideband_metadata_2 = dict(\n",
    "    num_iq_samples_dataset = num_iq_samples_dataset,\n",
    "    fft_size = fft_size,\n",
    "    impairment_level = impairment_level,\n",
    "    num_signals_max = 3,\n",
    ")\n",
    "print(wideband_metadata_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Option 3: Instantiate from YAML file\n",
    "# see wideband_example.yaml\n",
    "\n",
    "wideband_metadata_3 = \"wideband_example.yaml\"\n",
    "print(wideband_metadata_3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Infinite Dataset\n",
    "To create an infinite dataset, instantiate a NewWideband object and make sure the `num_samples` field of its WidebandMetadata is set to None.\n",
    "\n",
    "Notes:\n",
    "* You will not be able to access previously generated samples.\n",
    "* You cannot save an infinite dataset to disk."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchsig.datasets.dataset_metadata import WidebandMetadata\n",
    "from torchsig.datasets.wideband import NewWideband\n",
    "\n",
    "wideband_infinite_metadata = WidebandMetadata(\n",
    "    num_iq_samples_dataset = num_iq_samples_dataset,\n",
    "    fft_size = fft_size,\n",
    "    impairment_level = impairment_level,\n",
    "    num_signals_max = 3,\n",
    "    num_samples = None,\n",
    ")\n",
    "\n",
    "wideband_infinite = NewWideband(\n",
    "    dataset_metadata = wideband_infinite_metadata\n",
    ")\n",
    "# print(wideband_infinite)\n",
    "\n",
    "for i in range(10):\n",
    "    data, metadata = wideband_infinite[i]\n",
    "    if i == 5:\n",
    "        print(f\"IQ Data: {data.shape}\")\n",
    "        print(f\"Signal Metadata: {metadata}\")\n",
    "\n",
    "# should fail\n",
    "print()\n",
    "try:\n",
    "    wideband_infinite[0]\n",
    "except Exception as E:\n",
    "    print(E)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Finite Dataset\n",
    "To create a finite dataset, set `num_samples` to any positive integer.\n",
    "\n",
    "Note: you will still not be able to access previously generated samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchsig.datasets.dataset_metadata import WidebandMetadata\n",
    "from torchsig.datasets.wideband import NewWideband\n",
    "\n",
    "num_samples = 10\n",
    "\n",
    "wideband_finite_metadata = WidebandMetadata(\n",
    "    num_iq_samples_dataset = num_iq_samples_dataset,\n",
    "    fft_size = fft_size,\n",
    "    impairment_level = impairment_level,\n",
    "    num_signals_max = 3,\n",
    "    num_samples = num_samples,\n",
    ")\n",
    "\n",
    "wideband_finite = NewWideband(\n",
    "    dataset_metadata = wideband_finite_metadata\n",
    ")\n",
    "# print(wideband_finite)\n",
    "\n",
    "for i in range(11):\n",
    "    if i == 10: # should fail\n",
    "        try:\n",
    "            data, metadata = wideband_finite[i]\n",
    "        except Exception as E:\n",
    "            print(E)\n",
    "    else:\n",
    "        data, metadata = wideband_finite[i]\n",
    "        print(f\"sample {i} generated successfully\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Writing Dataset to Disk\n",
    "To access previously generated samples, or to save a finite dataset for later use, use the `DatasetCreator`. Pass in the dataset to be saved, the location to write it to (`root`), and whether to overwrite any existing dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchsig.datasets.dataset_metadata import WidebandMetadata\n",
    "from torchsig.datasets.wideband import NewWideband\n",
    "from torchsig.utils.writer import DatasetCreator\n",
    "from torchsig.signals.signal_lists import TorchSigSignalLists\n",
    "\n",
    "root = \"./datasets/wideband_example\"\n",
    "class_list = TorchSigSignalLists.all_signals\n",
    "num_samples = len(class_list) * 10\n",
    "seed = 123456789\n",
    "\n",
    "wideband_finite_metadata = WidebandMetadata(\n",
    "    num_iq_samples_dataset = num_iq_samples_dataset,\n",
    "    fft_size = fft_size,\n",
    "    impairment_level = impairment_level,\n",
    "    num_signals_max = 3,\n",
    "    num_samples = num_samples,\n",
    "    seed = seed,\n",
    ")\n",
    "\n",
    "wideband_finite = NewWideband(\n",
    "    dataset_metadata = wideband_finite_metadata\n",
    ")\n",
    "\n",
    "dataset_creator = DatasetCreator(\n",
    "    dataset = wideband_finite,\n",
    "    root = root,\n",
    "    # overwrite = True,\n",
    "    # increase the batch size and num_workers to speed up\n",
    "    # batch_size = 16,\n",
    "    # num_workers = 16\n",
    ")\n",
    "\n",
    "dataset_creator.create()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reading Dataset from Disk\n",
    "Once a dataset has been written to disk, you can load it back in by instantiating a `StaticWideband`.\n",
    "\n",
    "Warning: The following code assumes you have already run the code in the section **Writing Dataset to Disk**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchsig.datasets.wideband import StaticWideband\n",
    "\n",
    "static_wideband = StaticWideband(\n",
    "    root = root,\n",
    "    impaired = impairment_level > 0,\n",
    ")\n",
    "\n",
    "# can access any sample\n",
    "print(static_wideband[0])\n",
    "print(static_wideband[5])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the dataset written to disk is raw (that is, no Transforms or Target Transforms were applied before writing), you can define any transforms and target transforms you like when instantiating the StaticWideband."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Dataset Raw: {static_wideband.raw}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "from torchsig.datasets.wideband import StaticWideband\n",
    "from torchsig.transforms.dataset_transforms import Spectrogram\n",
    "from torchsig.transforms.target_transforms import (\n",
    "    ClassName,\n",
    "    Start,\n",
    "    Stop,\n",
    "    LowerFreq,\n",
    "    UpperFreq,\n",
    "    SNR\n",
    ")\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "target_transform = [\n",
    "    ClassName(),\n",
    "    Start(),\n",
    "    Stop(),\n",
    "    LowerFreq(),\n",
    "    UpperFreq(),\n",
    "    SNR()\n",
    "]\n",
    "\n",
    "static_wideband = StaticWideband(\n",
    "    root = root,\n",
    "    impaired = impairment_level > 0,\n",
    "    transforms = Spectrogram(fft_size = fft_size),\n",
    "    target_transforms = target_transform,\n",
    ")\n",
    "sample_rate = static_wideband.dataset_metadata.sample_rate\n",
    "noise_power_db = static_wideband.dataset_metadata.noise_power_db\n",
    "\n",
    "num_show = 4\n",
    "num_cols = 2\n",
    "num_rows = num_show // num_cols\n",
    "\n",
    "fig = plt.figure(figsize=(18,12))\n",
    "fig.tight_layout()\n",
    "\n",
    "xmin = 0\n",
    "xmax = 1\n",
    "ymin = -sample_rate / 2\n",
    "ymax = sample_rate / 2\n",
    "\n",
    "for i in range(num_show):\n",
    "\n",
    "    data, targets = static_wideband[i]\n",
    "    ax = fig.add_subplot(num_rows,num_cols,i + 1)\n",
    "\n",
    "    pos = ax.imshow(data,extent=[xmin,xmax,ymin,ymax],aspect='auto',cmap='Wistia',vmin=noise_power_db)\n",
    "    fig.colorbar(pos, ax=ax)\n",
    "\n",
    "    for t in targets:\n",
    "        classname, start, stop, lower, upper, snr = t\n",
    "\n",
    "\n",
    "        ax.plot([start,start],[lower,upper],'b',alpha=0.5)\n",
    "        ax.plot([stop, stop],[lower,upper],'b',alpha=0.5)\n",
    "        ax.plot([start,stop],[lower,lower],'b',alpha=0.5)\n",
    "        ax.plot([start,stop],[upper,upper],'b',alpha=0.5)\n",
    "        textDisplay = str(classname) + ', SNR = ' + str(snr) + ' dB'\n",
    "        ax.text(start,lower,textDisplay, bbox=dict(facecolor='w', alpha=0.5, linewidth=0))\n",
    "        ax.set_xlim([0,1])\n",
    "        ax.set_ylim([-sample_rate/2,sample_rate/2])\n",
    "\n",
    "        # fig.suptitle(f\"class: {classname}\", fontsize=16)\n",
    "        # plt.ylabel(\"Frequency (Hz)\")\n",
    "        # plt.xlabel(\"Time\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Wideband Dataset Statistics\n",
    "\n",
    "Below are some plots and statistics about the Wideband dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "class_list = static_wideband.dataset_metadata.class_list\n",
    "class_counter = {class_name: 0 for class_name in class_list}\n",
    "snr_list = []\n",
    "num_signals_per_sample = []\n",
    "\n",
    "for sample in tqdm(static_wideband, desc = \"Calculating Wideband Stats\"):\n",
    "    data, targets = sample\n",
    "    num_signals_per_sample.append(len(targets))\n",
    "    for t in targets:\n",
    "        classname, start, stop, lower, upper, snr = t\n",
    "        class_counter[classname] += 1\n",
    "        snr_list.append(snr)\n",
    "\n",
    "class_counts = list(class_counter.values())\n",
    "class_names = list(class_counter.keys())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class Distribution Pie Chart\n",
    "# by default, the class distribution is None aka uniform\n",
    "print(f\"Class Distribution Setting: {static_wideband.dataset_metadata.class_distribution}\")\n",
    "plt.figure(figsize=(15, 15))\n",
    "plt.pie(class_counts, labels = class_names)\n",
    "plt.title(\"Class Distribution Pie Chart\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class Distribution Bar Chart\n",
    "plt.figure(figsize=(18, 4))\n",
    "plt.bar(class_names, class_counts)\n",
    "plt.xticks(rotation=90)\n",
    "plt.title(\"Class Distribution Bar Chart\")\n",
    "plt.xlabel(\"Modulation Class Name\")\n",
    "plt.ylabel(\"Counts\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of signals per sample Distribution\n",
    "import numpy as np\n",
    "# Wideband's default number of signals per sample is 3\n",
    "print(f\"Min num signals Setting: {static_wideband.dataset_metadata.num_signals_min}\")\n",
    "print(f\"Max num signals Setting: {static_wideband.dataset_metadata.num_signals_max}\")\n",
    "\n",
    "total = sum(num_signals_per_sample)\n",
    "avg = np.mean(np.asarray(num_signals_per_sample))\n",
    "\n",
    "plt.figure(figsize=(11,8))\n",
    "# bin edges at -0.5, 0.5, ..., max + 0.5 so each integer count (including 0) gets its own bin\n",
    "plt.hist(x=num_signals_per_sample, bins=np.arange(max(num_signals_per_sample) + 2) - 0.5)\n",
    "plt.xticks(np.arange(0, max(num_signals_per_sample) + 1, 1).tolist())\n",
    "plt.title(f\"Distribution of Number of Signals Per Sample\\nTotal Number: {total}, Average: {avg:.2f}\")\n",
    "plt.xlabel(\"Number of Signals Per Sample\")\n",
    "plt.ylabel(\"Counts\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SNR Distributions\n",
    "# Default min = 0 dB and max = 50 dB\n",
    "print(f\"Min SNR Setting: {static_wideband.dataset_metadata.snr_db_min}\")\n",
    "print(f\"Max SNR Setting: {static_wideband.dataset_metadata.snr_db_max}\")\n",
    "plt.figure(figsize=(11, 4))\n",
    "plt.hist(x=snr_list, bins=100)\n",
    "plt.title(\"SNR Distribution\")\n",
    "plt.xlabel(\"SNR Bins (dB)\")\n",
    "plt.ylabel(\"Counts\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "formats": "ipynb,py:percent"
  },
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
